-- | This module helps parse the proto OpDef into a Haskell type which is more
-- descriptive of how the attributes and arguments will be used in the
-- generated code.
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
module TensorFlow.OpGen.ParsedOp
    ( ParsedOp(..)
    , Name(..)
    , HaskellName(..)
    , TFName(..)
    , Attr(..)
    , AttrType(..)
    , AttrBaseType(..)
    , TypeParam(..)
    , ParsedArg(..)
    , ParsedArgCase(..)
    , ArgType(..)
    , ArgKind(..)
    , parseOp
    , camelCase
    ) where

import Data.Char (toUpper, toLower)
import Data.List (sortBy)
import Data.List.NonEmpty (NonEmpty, nonEmpty)
import Data.Maybe (mapMaybe)
import Data.Ord (comparing)
import qualified Data.Set as Set
import Data.Text (Text)
import qualified Data.Text as Text
import Lens.Family2 ((^.))
import Proto.Tensorflow.Core.Framework.AttrValue_Fields (list)
import Proto.Tensorflow.Core.Framework.OpDef
    ( OpDef
    , OpDef'ArgDef
    , OpDef'AttrDef
    )
import Proto.Tensorflow.Core.Framework.OpDef_Fields
    ( allowedValues
    , attr
    , maybe'defaultValue
    , description
    , name
    , inputArg
    , isRef
    , isStateful
    , outputArg
    , summary
    , typeListAttr
    , numberAttr
    , typeAttr
    , type'
    )

import Proto.Tensorflow.Core.Framework.Types (DataType(DT_RESOURCE))

data ParsedOp = ParsedOp
    { parsedOpName :: Name
    , parsedOpSummary :: Text
    , parsedOpDescription :: Text
    , parsedInputs :: [ParsedArg]
    , parsedOutputs :: [ParsedArg]
    , explicitInputAttrs :: [Attr AttrType]
        -- ^ Attributes that must be set explicitly when creating the op.
        -- Associated with the type of the attribute.
    , inferredTypeAttrs :: [Attr TypeParam]
        -- ^ Attributes that are type parameters.
    , inferredListSizeAttrs :: [Attr (NonEmpty Name)]
        -- Attributes which are list sizes (ints) that are inferred automatically
        -- from one or more of the input tensors.
        -- Associated with the list of tensors whose size it describes.
    , parsedOpIsMonadic :: Bool
        -- ^ Whether this op is stateful or takes a stateful input.  Such ops
        -- should not be CSE'd and must be monadic in our API (i.e., return a
        -- Build action).
    }

data Name = Name
    { haskellName :: HaskellName
    , tfName :: TFName
    }

-- | A raw name as specified in the OpDef proto.
newtype TFName = TFName { unTFName :: Text }
    deriving (Eq, Ord)

-- | A name that's appropriate for a variable in a Haskell source file.
newtype HaskellName = HaskellName { unHaskellName :: Text }

-- | A named attribute, associated with some information about it.
data Attr a = Attr
    { attrName :: Name
    , attrDescription :: Text
    , attrInfo :: a
    }

-- | The type of an attribute.
data AttrType = AttrSingle AttrBaseType
                | AttrList AttrBaseType
                deriving Eq

data AttrBaseType = AttrBytes | AttrInt64 | AttrFloat | AttrBool
                | AttrType | AttrShape | AttrTensor | AttrFunc
                deriving Eq

data TypeParam = TypeParam
    { typeParamIsList :: Bool
    , typeParamRestrictions :: Maybe (NonEmpty DataType)
        -- ^ The list of allowed types (see: TensorFlow.Types.OneOf).
        -- If 'Nothing', then any type is acceptable.
    }

-- | An input or output argument (Tensor) for an op.
data ParsedArg = ParsedArg
    { parsedArgName :: Name
    , parsedArgDescription :: Text
    , parsedArgCase :: ParsedArgCase
    }

data ParsedArgCase
    = SimpleArg { argType :: ArgType, argKind :: ArgKind }
    | ListArg
        { argLength :: Name  -- ^ The attribute that specifies this list's length.
        , argType :: ArgType
        , argKind :: ArgKind
        }
    | MixedListArg { argTypeAttr :: Name, argKind :: ArgKind }
        -- ^ A heterogeneous list.

maybeArgType :: ParsedArgCase -> Maybe ArgType
maybeArgType MixedListArg{} = Nothing
maybeArgType a = Just $ argType a

-- | The type of an argument.
data ArgType
    = ArgTypeFixed DataType -- ^ A fixed type.
    | ArgTypeAttr Name  -- ^ A type that depends on an attribute.

-- The kind of an op input or output (not including the argument type `a`).
data ArgKind
    = ArgTensorRef -- Tensor Ref a
    | ArgTensorValue -- Tensor Value a
    | ArgTensorBuild -- Tensor Build a
    | ArgSomeTensor Text -- Tensor v a; the Text is the variable 'v'.
    deriving (Eq)

isRefCase :: ParsedArgCase -> Bool
isRefCase a
    | ArgTensorRef <- argKind a = True
    | Just (ArgTypeFixed DT_RESOURCE) <- maybeArgType a = True
    | otherwise = False

makeName :: Text -> Name
makeName n = Name
    { haskellName = HaskellName $ fixReservedName $ lowCase n
    , tfName = TFName n
    }

-- | Change a name so it doesn't conflict with any Haskell keywords.
fixReservedName :: Text -> Text
fixReservedName n
    | n `Set.member` reservedKeywords = n <> "'"
    | otherwise = n

reservedKeywords :: Set.Set Text
reservedKeywords = Set.fromList $
    -- Haskell2010 keywords:
    -- https://www.haskell.org/onlinereport/haskell2010/haskellch2.html#x7-180002.4
    -- We don't include keywords that are allowed to be variable names,
    -- in particular: "as", "forall", and "hiding".
    [ "case"
    , "class"
    , "data"
    , "default"
    , "deriving"
    , "do"
    , "else"
    , "foreign"
    , "if"
    , "import"
    , "in"
    , "infix"
    , "infixl"
    , "infixr"
    , "instance"
    , "let"
    , "module"
    , "newtype"
    , "of"
    , "then"
    , "type"
    , "where"
    ]
    ++  -- Nonstandard extensions
    [ "mdo"   -- RecursiveDo
    , "rec"   -- Arrows, RecursiveDo
    , "proc"  -- Arrows
    ]

-- | Lower-case the given text.
lowCase :: Text -> Text
lowCase = forceCase toLower

forceCase :: (Char -> Char) -> Text -> Text
forceCase convert s = maybe "" (\(c, cs) -> Text.cons (convert c) cs)
                      (Text.uncons s)

camelCase :: Text -> Text
camelCase s = Text.concat $ map upCase
                          $ Text.splitOn "_" s

-- | Upper-case the given text.
upCase :: Text -> Text
upCase = forceCase toUpper


parseOp :: OpDef -> ParsedOp
parseOp o = ParsedOp
    { parsedOpName = makeName $ o ^. name
    , parsedOpSummary = o ^. summary
    , parsedOpDescription = o ^. description
    , ..
    }
  where
    parsedOpIsMonadic = o ^. isStateful
                    || any (isRefCase . parsedArgCase) parsedInputs
                    || null (o ^. outputArg)
    parsedInputs = zipWith (\t a -> parseArg a (inputTensorKind t a))
                                        tensorKindParams (o ^. inputArg) 
    tensorKindParams = ["v'" <> Text.pack (show x) | x <- [1::Integer ..]]
    parsedOutputs = map (\a -> parseArg a (outputTensorKind parsedOpIsMonadic a))
                        (o ^. outputArg)
    -- Integer attributes that can be inferred from the size of at least one
    -- input list.
    inferredListSizeAttrs = mapMaybeAttrs (getInferredListSizeAttr parsedInputs)
                                $ o ^. attr
    implicitAttrs = Set.fromList $ map tfName $
                        map attrName inferredTypeAttrs
                            ++ map attrName inferredListSizeAttrs
    inferredTypeAttrs = mapMaybeAttrs (getInferredTypeAttr argTypeParams) $ o ^. attr
    argTypeParams = Set.fromList $ map tfName $
                        mapMaybe (getArgTypeParam . parsedArgCase) $
                            parsedInputs ++ parsedOutputs
    -- Attributes that can't be inferred and don't have defaults, so must be
    -- passed as separate arguments to the op.
    explicitInputAttrs = sortBy (comparing (tfName . attrName))
                        $ mapMaybeAttrs (getExplicitInputAttr o implicitAttrs)
                        $ o ^. attr

-- TODO(judahjacobson): Some arguments should be refs.
inputTensorKind :: Text -> OpDef'ArgDef -> ArgKind
inputTensorKind v a
    | a ^. isRef = ArgTensorRef
    | otherwise = ArgSomeTensor v

outputTensorKind :: Bool -> OpDef'ArgDef -> ArgKind
outputTensorKind isMonadic a
    | a ^. isRef = ArgTensorRef
    | isMonadic = ArgTensorValue
    | otherwise = ArgTensorBuild

getExplicitInputAttr :: OpDef -> Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType
getExplicitInputAttr o implicitAttrs a
    | TFName (a ^. name) `Set.notMember` implicitAttrs
    , a ^. maybe'defaultValue == Nothing
    , t <- parseAttrType o (a ^. type')
    , t `elem` map AttrSingle
                    [AttrBool, AttrInt64, AttrFloat, AttrType, AttrShape, AttrBytes]
                ++ [AttrList AttrType] = Just t
    | otherwise = Nothing

getInferredTypeAttr :: Set.Set TFName -> OpDef'AttrDef -> Maybe TypeParam
getInferredTypeAttr argTypeParams a
    | TFName (a ^. name) `notElem` argTypeParams = Nothing
    | a ^. type' == "type" = Just $ TypeParam False allowed
    | a ^. type' == "list(type)" = Just $ TypeParam True allowed
    | otherwise = Nothing
  where
    allowed = nonEmpty (a ^. allowedValues . list . type')

getArgTypeParam :: ParsedArgCase -> Maybe Name
getArgTypeParam SimpleArg { argType = ArgTypeAttr n} = Just n
getArgTypeParam ListArg { argType = ArgTypeAttr n} = Just n
getArgTypeParam MixedListArg { argTypeAttr = n } = Just n
getArgTypeParam _ = Nothing

getInferredListSizeAttr :: [ParsedArg] -> OpDef'AttrDef -> Maybe (NonEmpty Name)
getInferredListSizeAttr inputs a
    | a ^. type' == "int"
        = nonEmpty [t | ParsedArg { parsedArgName = t
                                  , parsedArgCase
                                        = ListArg { argLength = n }
                                  } <- inputs
                      , TFName (a ^. name) == tfName n]
    | otherwise = Nothing

-- | Like mapMaybe, but associates the attribute name/description with the given info.
mapMaybeAttrs :: (OpDef'AttrDef -> Maybe a) -> [OpDef'AttrDef] -> [Attr a]
mapMaybeAttrs f = mapMaybe $ \a -> do
                            x <- f a
                            Just Attr
                                { attrName = makeName (a ^. name)
                                , attrDescription = a ^. description
                                , attrInfo = x
                                }

parseArg :: OpDef'ArgDef -> ArgKind -> ParsedArg
parseArg a tKind = ParsedArg
    { parsedArgName = makeName (a ^. name)
    , parsedArgDescription = a ^. description
    , parsedArgCase = parseArgCase a tKind
    }

parseArgCase :: OpDef'ArgDef -> ArgKind -> ParsedArgCase
parseArgCase a tKind
    | Just n <- maybeAttr (a ^. typeListAttr) = MixedListArg n tKind
    | Just n <- maybeAttr (a ^. numberAttr) = ListArg n thisArgType tKind
    | otherwise = SimpleArg thisArgType tKind
  where
    thisArgType
        | Just n <- maybeAttr (a ^. typeAttr) = ArgTypeAttr n
        | otherwise = ArgTypeFixed (a ^. type')
    maybeAttr :: Text -> Maybe Name
    maybeAttr "" = Nothing
    maybeAttr t = Just $ makeName t

parseAttrType :: OpDef -> Text -> AttrType
parseAttrType o = \case
    "string" -> AttrSingle AttrBytes
    "int" -> AttrSingle AttrInt64
    "float" -> AttrSingle AttrFloat
    "bool" -> AttrSingle AttrBool
    "type" -> AttrSingle AttrType
    "shape" -> AttrSingle AttrShape
    "tensor" -> AttrSingle AttrTensor
    "func" -> AttrSingle AttrFunc
    "list(string)" -> AttrList AttrBytes
    "list(int)" -> AttrList AttrInt64
    "list(float)" -> AttrList AttrFloat
    "list(bool)" -> AttrList AttrBool
    "list(type)" -> AttrList AttrType
    "list(shape)" -> AttrList AttrShape
    "list(tensor)" -> AttrList AttrTensor
    "list(func)" -> AttrList AttrFunc
    t -> error $ "parseAttrType: unrecognized type " ++ show t
              ++ " for op " ++ show (o ^. name)
